############################################
# Illustration:
# Interaction and interaction controls
############################################

library(ggplot2)
library(scales)
library(viridis)

set.seed(12345)

# Total number of observations (NOT per group): `kids` below splits them at
# random into two groups of roughly equal size.
n <- 700

# Group variable: 0/1 indicator for having kids, drawn with probability .5
kids <- rbinom(n, 1, .5)

# Predictor: days on vacation, simulated from a standard normal
vacation <- rnorm(n)

# Confounder: work is shifted down by 2 on average for people with kids,
# so it is correlated with the kids grouping
work <- -2 * kids + rnorm(n)

# Outcome: having kids lowers relaxation by 3 and each unit of work by 0.3.
# Vacation raises relaxation (slope 1) only for people without kids, i.e.
# the true interaction is vacation x kids, not vacation x work.
relaxation <- -3 * kids + (1 - kids) * vacation - 0.3 * work + rnorm(n, sd = 1)

# Put into data frame
dat <- data.frame(vacation, relaxation, kids, work)

# Colour mapper for the manually drawn regression lines: maps a work value
# onto a 256-colour viridis gradient spanning the observed range of
# dat$work. Intended to approximate the point colours that
# scale_color_viridis_c() assigns, so lines and points roughly match.
viridis_fun <- col_numeric(
  palette = viridis(256),
  domain = range(dat$work)
)

# Predict the outcome of a fitted model at the given predictor values.
#
# model: fitted model with a predict() method (lm fits in this script)
# vacation, kids, work: predictor values; combined row-wise with
#   data.frame(), so length-1 values are recycled against longer vectors
# Returns an unnamed numeric vector with one prediction per row.
predictions <- function(model, vacation, kids, work) {
  grid <- data.frame(vacation = vacation, kids = kids, work = work)
  fitted_values <- predict(model, newdata = grid)
  as.numeric(fitted_values)
}

# Opacity of the data points in plots that also draw regression lines,
# so the lines remain visible on top of the points
overlap <- 0.5

#####################
# Just the data points
#####################
# Scatter plot of the raw data coloured by work, with legend, axis text and
# axis ticks hidden.
plot_data <- ggplot(data = dat) +
  geom_point(aes(x = vacation, y = relaxation, color = work)) +
  theme_classic() +
  scale_color_viridis_c() +
  theme(legend.position = "none",
        axis.text.x = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.x = element_blank(),
        axis.ticks.y = element_blank()) +
  xlab("Days on vacation") +
  ylab("Relaxation")
print(plot_data)
# Save the plot object explicitly rather than relying on ggsave()'s
# last_plot() default, which picks up whichever plot was built most recently.
ggsave("00_data.png", plot = plot_data, width = 6, height = 4.5)

#####################
# Initial (wrong) interaction analysis
#####################
# Naive model: the vacation slope may vary with work only; kids is ignored.
model_1 <- lm(relaxation ~ vacation * work, data = dat)

# Work values at which regression lines are drawn: mean and mean +/- 1 SD
work_levels <- c(
  low  = mean(dat$work) - sd(dat$work),
  mid  = mean(dat$work),
  high = mean(dat$work) + sd(dat$work)
)

# 100 evenly spaced vacation values spanning the observed range
vacation_grid <- seq(from = min(dat$vacation), to = max(dat$vacation),
                     length.out = 100)

# Predicted relaxation at each work level. model_1 does not use kids;
# kids = 0 is passed only because predictions() requires the argument.
predict_model_1 <- function(level) {
  predictions(model = model_1, vacation = vacation_grid, kids = 0,
              work = work_levels[[level]])
}
pred_data_1 <- data.frame(
  vacation = vacation_grid,
  low      = predict_model_1("low"),
  mid      = predict_model_1("mid"),
  high     = predict_model_1("high")
)

# Plot the data with one fitted line per work level, each coloured to match
# its work value
plot_model_1 <- ggplot(data = dat) +
  geom_point(aes(x = vacation, y = relaxation, color = work),
             alpha = overlap) +
  theme_classic() +
  scale_color_viridis_c() +
  theme(legend.position = "none",
        axis.text.x = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.x = element_blank(),
        axis.ticks.y = element_blank()) +
  xlab("Days on vacation") +
  ylab("Relaxation") +
  geom_line(data = pred_data_1, aes(x = vacation, y = low),
            color = viridis_fun(work_levels[["low"]])) +
  geom_line(data = pred_data_1, aes(x = vacation, y = mid),
            color = viridis_fun(work_levels[["mid"]])) +
  geom_line(data = pred_data_1, aes(x = vacation, y = high),
            color = viridis_fun(work_levels[["high"]]))
print(plot_model_1)
# Pass the plot explicitly instead of relying on ggsave()'s last_plot()
ggsave("01_model.png", plot = plot_model_1, width = 6, height = 4.5)


#####################
# Highlight the different groups
#####################
# Same scatter plot, with point shape marking the kids groups:
# kids == 0 -> shape 3 (plus), kids == 1 -> shape 16 (filled circle).
plot_groups <- ggplot(data = dat) +
  geom_point(aes(x = vacation, y = relaxation, color = work, shape = as.factor(kids))) +
  theme_classic() +
  scale_color_viridis_c() +
  scale_shape_manual(values = c(3, 16)) +
  theme(legend.position = "none",
        axis.text.x = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.x = element_blank(),
        axis.ticks.y = element_blank()) +
  xlab("Days on vacation") +
  ylab("Relaxation")
print(plot_groups)
# Pass the plot explicitly instead of relying on ggsave()'s last_plot()
ggsave("02_groups.png", plot = plot_groups, width = 6, height = 4.5)



#####################
# Add main effect of kids
#####################
# Adds a kids main effect (an intercept shift), but the vacation slope still
# varies only with work.
model_2 <- lm(relaxation ~ vacation * work + kids, data = dat)

# Work values at which regression lines are drawn: mean and mean +/- 1 SD
work_levels <- c(
  low  = mean(dat$work) - sd(dat$work),
  mid  = mean(dat$work),
  high = mean(dat$work) + sd(dat$work)
)

# 100 evenly spaced vacation values spanning the observed range
vacation_grid <- seq(from = min(dat$vacation), to = max(dat$vacation),
                     length.out = 100)

# Predicted relaxation for each work level, without (kids = 0) and with
# (kids = 1) kids
predict_model_2 <- function(kids, level) {
  predictions(model = model_2, vacation = vacation_grid, kids = kids,
              work = work_levels[[level]])
}
pred_data_2 <- data.frame(
  vacation     = vacation_grid,
  low_no_kids  = predict_model_2(0, "low"),
  mid_no_kids  = predict_model_2(0, "mid"),
  high_no_kids = predict_model_2(0, "high"),
  low_kids     = predict_model_2(1, "low"),
  mid_kids     = predict_model_2(1, "mid"),
  high_kids    = predict_model_2(1, "high")
)

# Plot: dashed lines = no kids, solid lines = kids; line colour matches the
# work level of each line
plot_model_2 <- ggplot(data = dat) +
  geom_point(aes(x = vacation, y = relaxation, color = work, shape = as.factor(kids)),
             alpha = overlap) +
  theme_classic() +
  scale_color_viridis_c() +
  scale_shape_manual(values = c(3, 16)) +
  theme(legend.position = "none",
        axis.text.x = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.x = element_blank(),
        axis.ticks.y = element_blank()) +
  xlab("Days on vacation") +
  ylab("Relaxation") +
  geom_line(data = pred_data_2, aes(x = vacation, y = low_no_kids),
            linetype = "longdash", color = viridis_fun(work_levels[["low"]])) +
  geom_line(data = pred_data_2, aes(x = vacation, y = mid_no_kids),
            linetype = "longdash", color = viridis_fun(work_levels[["mid"]])) +
  geom_line(data = pred_data_2, aes(x = vacation, y = high_no_kids),
            linetype = "longdash", color = viridis_fun(work_levels[["high"]])) +
  geom_line(data = pred_data_2, aes(x = vacation, y = low_kids),
            color = viridis_fun(work_levels[["low"]])) +
  geom_line(data = pred_data_2, aes(x = vacation, y = mid_kids),
            color = viridis_fun(work_levels[["mid"]])) +
  geom_line(data = pred_data_2, aes(x = vacation, y = high_kids),
            color = viridis_fun(work_levels[["high"]]))
print(plot_model_2)
# Pass the plot explicitly instead of relying on ggsave()'s last_plot()
ggsave("03_model.png", plot = plot_model_2, width = 6, height = 4.5)


#####################
# Correct interaction plot
#####################
# Includes vacation x kids (the interaction used to simulate relaxation)
# alongside vacation x work as a control interaction.
model_3 <- lm(relaxation ~ vacation * work + vacation * kids, data = dat)

# Work values at which regression lines are drawn: mean and mean +/- 1 SD
work_levels <- c(
  low  = mean(dat$work) - sd(dat$work),
  mid  = mean(dat$work),
  high = mean(dat$work) + sd(dat$work)
)

# 100 evenly spaced vacation values spanning the observed range
vacation_grid <- seq(from = min(dat$vacation), to = max(dat$vacation),
                     length.out = 100)

# Predicted relaxation for each work level, without (kids = 0) and with
# (kids = 1) kids
predict_model_3 <- function(kids, level) {
  predictions(model = model_3, vacation = vacation_grid, kids = kids,
              work = work_levels[[level]])
}
pred_data_3 <- data.frame(
  vacation     = vacation_grid,
  low_no_kids  = predict_model_3(0, "low"),
  mid_no_kids  = predict_model_3(0, "mid"),
  high_no_kids = predict_model_3(0, "high"),
  low_kids     = predict_model_3(1, "low"),
  mid_kids     = predict_model_3(1, "mid"),
  high_kids    = predict_model_3(1, "high")
)

# Plot: dashed lines = no kids, solid lines = kids; line colour matches the
# work level of each line
plot_model_3 <- ggplot(data = dat) +
  geom_point(aes(x = vacation, y = relaxation, color = work, shape = as.factor(kids)),
             alpha = overlap) +
  theme_classic() +
  scale_color_viridis_c() +
  scale_shape_manual(values = c(3, 16)) +
  theme(legend.position = "none",
        axis.text.x = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.x = element_blank(),
        axis.ticks.y = element_blank()) +
  xlab("Days on vacation") +
  ylab("Relaxation") +
  geom_line(data = pred_data_3, aes(x = vacation, y = low_no_kids),
            linetype = "longdash", color = viridis_fun(work_levels[["low"]])) +
  geom_line(data = pred_data_3, aes(x = vacation, y = mid_no_kids),
            linetype = "longdash", color = viridis_fun(work_levels[["mid"]])) +
  geom_line(data = pred_data_3, aes(x = vacation, y = high_no_kids),
            linetype = "longdash", color = viridis_fun(work_levels[["high"]])) +
  geom_line(data = pred_data_3, aes(x = vacation, y = low_kids),
            color = viridis_fun(work_levels[["low"]])) +
  geom_line(data = pred_data_3, aes(x = vacation, y = mid_kids),
            color = viridis_fun(work_levels[["mid"]])) +
  geom_line(data = pred_data_3, aes(x = vacation, y = high_kids),
            color = viridis_fun(work_levels[["high"]]))
print(plot_model_3)
# Pass the plot explicitly instead of relying on ggsave()'s last_plot()
ggsave("04_model.png", plot = plot_model_3, width = 6, height = 4.5)








